import numpy as np
import matplotlib.pyplot as plt


# Dimensions whose result files are loaded and plotted.
ds = [3, 50] # [3, 10, 20, 50 ,100]
# Sample sizes used as the x-axis; each result file has one row per entry.
n_samples = [10,100,1000,10000]
# Number of columns per row in each result file (presumably independent
# repetitions of the experiment -- confirm against the generating script).
n_runs = 20

# Result buffers indexed as [dimension, sample size, run]. The middle axis is
# derived from `n_samples` so the two cannot drift apart when the list changes.
L_ssw = np.zeros((len(ds), len(n_samples), n_runs))
L_w = np.zeros((len(ds), len(n_samples), n_runs))

# Assigning into the pre-sized buffers makes a file with an unexpected shape
# fail loudly here instead of producing a misaligned plot later on.
for j, d in enumerate(ds):
    L_ssw[j] = np.loadtxt(f"./ssw_sample_d{d}", delimiter=",")
    L_w[j] = np.loadtxt(f"./w_sample_d{d}", delimiter=",")
    
    
#fig = plt.figure() #figsize=(18,6))

# Single compact figure that overlays both distances for each selected dimension.
fig = plt.figure(figsize=(6, 3))

for idx, dim in enumerate(ds):
    # Only dimensions 3 and 50 are drawn, even if `ds` lists more of them.
    if dim not in (3, 50):
        continue

    # SSW is plotted first, then W, for the current dimension.
    for tex_name, runs in ((r"$SSW_2^2$", L_ssw[idx]), (r"$W_2^2$", L_w[idx])):
        # Collapse the last axis into a mean curve and its standard deviation.
        center = runs.mean(axis=-1)
        spread = runs.std(axis=-1)

        plt.loglog(n_samples, center, label=tex_name + ", d=" + str(dim))
        # Shaded band spanning one standard deviation either side of the mean.
        plt.fill_between(
            n_samples,
            center - spread,
            center + spread,
            alpha=0.5,
        )

plt.xlabel("Number of Samples", fontsize=13)
plt.ylabel("Distances", fontsize=13)
plt.legend(fontsize=13)

# Write the PDF before show(), then open the interactive window.
plt.savefig(
    "./Comparison_Sample_Complexity_.pdf",
    format="pdf",
    bbox_inches="tight",
)
plt.show()


#plt.xlabel("Number of samples in each distribution", fontsize=13)
#plt.ylabel("Seconds", fontsize=13)
    #plt.yscale("log")
    #plt.xscale("log")
    
#plt.legend(fontsize=13, bbox_to_anchor=(0,1.02,1,0.2), loc="lower left", ncol=2)
#plt.title("Computational Time", fontsize=13)
#plt.grid(True)
#plt.savefig("./Comparison_SW_W2.pdf", format="pdf", bbox_inches="tight")
#plt.show()
